#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

"""
Reference:

FlowNet: Learning Optical Flow with Convolutional Networks
Philipp Fischer, Alexey Dosovitskiy, Eddy Ilg, Philip Häusser, Caner Hazırbaş, Vladimir Golkov,
Patrick van der Smagt, Daniel Cremers, Thomas Brox,
https://arxiv.org/abs/1504.06852

Note: FlowNetS is as explained in https://arxiv.org/pdf/1504.06852.pdf
Note: FlowNetLite is our flavour that uses depthwise separable convolutions -
but the structure is the same as FlowNetS
"""



import torch
from ... import xnn

__all__ = [
    'flownets', 'flownetslite', 'get_config'
]

###################################################
def get_config():
    """Return the default model configuration node for the FlowNet variants."""
    cfg = xnn.utils.ConfigNode()
    cfg.input_channels = 3
    cfg.num_classes = None
    cfg.output_type = None
    return cfg


#######################################################################################################
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, conv_type=None):
    """Build a conv+norm+activation block.

    conv_type selects the flavour: 'regular' uses a plain convolution,
    'depthwise' uses a depthwise-separable convolution pair. Any other
    value is rejected.
    """
    use_norm = bool(batchNorm)
    if conv_type == 'regular':
        return xnn.layers.ConvNormAct2d(in_planes, out_planes, kernel_size=kernel_size,
                                        stride=stride, bias=False, normalization=use_norm)
    if conv_type == 'depthwise':
        # depthwise-separable blocks take one normalization flag per stage
        return xnn.layers.ConvDWSepNormAct2d(in_planes, out_planes, kernel_size=kernel_size,
                                             stride=stride, bias=False,
                                             normalization=(use_norm, use_norm))
    assert False, 'invalid convolution type'


def predict_flow(in_planes, out_chan):
    """Flow prediction head: a plain 1x1 convolution layer."""
    head = xnn.layers.ConvLayer2d(in_planes, out_chan, kernel_size=1)
    return head


def upsample(in_planes, out_planes, size=None, scale_factor=2, upsample_mode="deconv", linear=False):
    """Build an upsampling stage as an nn.Sequential.

    'deconv' uses a transposed convolution (optionally followed by LeakyReLU);
    'bilinear' uses interpolation (optionally followed by a 1x1 conv block).
    When `linear` is True the trailing non-linearity/conv is omitted, so the
    stage is suitable for upsampling raw flow predictions.
    """
    if upsample_mode == 'deconv':
        stage = [torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4,
                                          stride=scale_factor, padding=1, bias=False)]
        if not linear:
            stage.append(torch.nn.LeakyReLU(0.1, inplace=True))
        #
    elif upsample_mode == "bilinear":
        stage = [torch.nn.Upsample(size=size, scale_factor=scale_factor, mode=upsample_mode)]
        if not linear:
            stage.append(xnn.layers.ConvNormAct2d(in_planes, out_planes, kernel_size=1))
        #
    else:
        assert False, 'invalid upsample_mode'

    return torch.nn.Sequential(*stage)


def crop_like(input, target):
    """Crop `input` spatially (H, W) to match `target`; returns `input` unchanged
    when the spatial sizes already agree."""
    if input.size()[2:] != target.size()[2:]:
        input = input[:, :, :target.size(2), :target.size(3)]
    return input


#######################################################################################################
class FlowNetS(torch.nn.Module):
    """FlowNetS encoder-decoder for optical flow estimation
    (https://arxiv.org/abs/1504.06852).

    The encoder downsamples with strided convolutions (conv1..conv6_1); the
    decoder refines in four stages, each concatenating the encoder skip
    feature, the upsampled decoder feature and the upsampled coarser flow
    prediction, with a flow head at every scale.
    """
    def __init__(self, input_channels=6, batch_norm=True, kernel_size1=7, output_channels=2, upsample_mode='deconv', conv_type='regular', last_stride=2):
        """
        Args:
            input_channels: channels of the (concatenated) input; 6 for two RGB frames
            batch_norm: enable normalization inside the conv blocks
            kernel_size1: kernel size of the first convolution (7 as in the paper)
            output_channels: flow channels per prediction head; a list/tuple is summed
            upsample_mode: 'deconv' (transposed conv) or 'bilinear'
            conv_type: 'regular' or 'depthwise' for encoder convolutions after conv1
            last_stride: stride of conv6 (2 per the paper; see note below)
        """
        super().__init__()
        # a list/tuple of output channels is flattened into a single head width
        self.output_channels = sum(output_channels) if isinstance(output_channels, (list,tuple)) else output_channels
        self.upsample_mode = upsample_mode
        self.conv_type = conv_type
        self.last_channel = 1024
        self.batch_norm = batch_norm
        self.last_stride = last_stride  # last stride as per flownet is 2, but quality may be better with 1
        self.deep_supervision = False   # batch norm (and may be sufficient pretraining)
                                        # may be required if we want to train without deep supervision

        # encoder: conv1 is always a regular conv; the rest follow conv_type
        self.conv1   = conv(self.batch_norm,  input_channels,   64, kernel_size=kernel_size1, stride=2, conv_type='regular')
        self.conv2   = conv(self.batch_norm,  64,  128, kernel_size=5, stride=2, conv_type=self.conv_type)
        self.conv3   = conv(self.batch_norm, 128,  256, kernel_size=5, stride=2, conv_type=self.conv_type)
        self.conv3_1 = conv(self.batch_norm, 256,  256, conv_type=self.conv_type)
        self.conv4   = conv(self.batch_norm, 256,  512, stride=2, conv_type=self.conv_type)
        self.conv4_1 = conv(self.batch_norm, 512,  512, conv_type=self.conv_type)
        self.conv5   = conv(self.batch_norm, 512,  512, stride=2, conv_type=self.conv_type)
        self.conv5_1 = conv(self.batch_norm, 512,  512, conv_type=self.conv_type)
        self.conv6   = conv(self.batch_norm, 512, 1024, stride=self.last_stride, conv_type=self.conv_type)
        self.conv6_1 = conv(self.batch_norm,1024, self.last_channel, conv_type=self.conv_type)

        # decoder upsamplers; input widths are skip + deconv + flow concatenations
        self.deconv5 = upsample(self.last_channel, 512, upsample_mode=self.upsample_mode)
        self.deconv4 = upsample(1024+self.output_channels, 256, upsample_mode=self.upsample_mode)
        self.deconv3 = upsample(768+self.output_channels, 128, upsample_mode=self.upsample_mode)
        self.deconv2 = upsample(384+self.output_channels, 64, upsample_mode=self.upsample_mode)

        # per-scale flow prediction heads (1x1 convolutions)
        self.predict_flow6 = predict_flow(1024, self.output_channels)
        self.predict_flow5 = predict_flow(1024+self.output_channels, self.output_channels)
        self.predict_flow4 = predict_flow(768+self.output_channels, self.output_channels)
        self.predict_flow3 = predict_flow(384+self.output_channels, self.output_channels)
        self.predict_flow2 = predict_flow(192+self.output_channels, self.output_channels)

        # linear (no activation) upsamplers for the raw flow maps themselves
        self.upsampled_flow6_to_5 = upsample(self.output_channels, self.output_channels, upsample_mode=self.upsample_mode, linear=True)
        self.upsampled_flow5_to_4 = upsample(self.output_channels, self.output_channels, upsample_mode=self.upsample_mode, linear=True)
        self.upsampled_flow4_to_3 = upsample(self.output_channels, self.output_channels, upsample_mode=self.upsample_mode, linear=True)
        self.upsampled_flow3_to_2 = upsample(self.output_channels, self.output_channels, upsample_mode=self.upsample_mode, linear=True)

        # weight init: kaiming-normal for conv/deconv weights, BN gamma=1 beta=0
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """Run the network.

        Args:
            x: input tensor, or a list/tuple of tensors concatenated on the
               channel dimension (e.g. the two input frames).

        Returns:
            [flow2] in eval mode (flow2 is interpolated to the input H x W);
            [flow2, flow3, flow4, flow5, flow6] (fine to coarse) when training
            with deep supervision enabled.
        """
        x = torch.cat(x, dim=1) if isinstance(x, (list,tuple)) else x
        h, w = x.size()[-2:]

        # encoder
        out_conv2 = self.conv2(self.conv1(x))
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # coarsest flow prediction, then refine stage by stage;
        # crop_like guards against off-by-one sizes from odd input dimensions
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)

        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)

        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)

        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)

        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
        flow2 = self.predict_flow2(concat2)

        # bring the finest prediction up to the input resolution
        flow2 = torch.nn.functional.interpolate(flow2, (h, w), mode='nearest')

        if self.training and (self.deep_supervision == True):
            return [flow2, flow3, flow4, flow5, flow6]
        else:
            return [flow2]


def flownets(model_config, pretrained):
    """FlowNetS model architecture from the
    "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)

    Args:
        model_config: config node providing input_channels and output_channels
        pretrained: pretrained weights of the network. will create a new one if not set
    """
    net = FlowNetS(batch_norm=False,
                   input_channels=model_config.input_channels[0],
                   output_channels=model_config.output_channels[0])
    return xnn.utils.load_weights(net, pretrained)


def flownetslite(model_config, pretrained):
    """FlowNetS Lite model architecture - a custom lite version using
    depthwise separable convolutions and bilinear upsampling."""
    net = FlowNetS(batch_norm=True, kernel_size1=3,
                   input_channels=model_config.input_channels[0],
                   output_channels=model_config.output_channels[0],
                   upsample_mode='bilinear', conv_type='depthwise', last_stride=1)
    return xnn.utils.load_weights(net, pretrained)

